from matplotlib import pyplot as plt
from sklearn.manifold import TSNE


def tsne_visualization(reps, labels, task_id, args):
    """Project ``reps`` to 2-D with t-SNE, save a scatter plot, and display it.

    The figure is written to ``<args.output_dir>/tsne<task_id>.png`` at
    120 dpi, then shown, then closed.

    Args:
        reps: Data handed directly to ``TSNE.fit_transform``
            (presumably an ``(n_samples, n_features)`` array -- confirm
            against the caller).
        labels: Per-point values passed as ``c=`` to ``plt.scatter`` to
            colour the points (presumably one numeric label per row of
            ``reps`` -- confirm against the caller).
        task_id: Identifier embedded in the output file name.
        args: Object exposing an ``output_dir`` attribute; the directory is
            not created here.

    Returns:
        None.
    """
    print("================tsne visualization start================")

    # Fixed random_state so repeated runs use the same seed.
    tsne = TSNE(n_components=2, random_state=0)
    embedded = tsne.fit_transform(reps)

    fig = plt.figure(figsize=(6, 5))
    try:
        plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, s=1)
        plt.colorbar()
        plt.title('t-SNE visualization of the data')

        # Save BEFORE showing: on interactive backends plt.show() blocks and
        # the figure is destroyed when its window is closed, so a savefig
        # issued afterwards would write an empty image.
        print("save tsne figure")
        fig.savefig(args.output_dir + '/tsne{}.png'.format(str(task_id)), dpi=120)

        plt.show()
    finally:
        # Release the figure so repeated calls (one per task) do not pile up
        # open figures in pyplot's global registry.
        plt.close(fig)
    print("================tsne visualization finish================")

